---
title: Transformations for 3D medical images
keywords: fastai
sidebar: home_sidebar
nb_path: "nbs/03_transforms.ipynb"
---
original = TensorDicom3D.create('../data/series/radiopaedia_10_85902_1.nii.gz')
mask = TensorMask3D.create('../data/masks/radiopaedia_10_85902_1.nii.gz')
original.show()
mask.show(add_to_existing = True, alpha = 0.25, cmap = 'jet')
Resize3D((10,50,50))(original, split_idx = 0).show()
Resize3D((10,50,50))(mask, split_idx = 0).show(add_to_existing = True, alpha = 0.25, cmap = 'jet')
Pad3D((10, 800, 800))(original).show()
In medical images, the left and right sides often cannot be distinguished from each other (e.g. scans of the head, hand, or knee). Therefore, the image orientation is stored in the image header, which enables the viewer system to display the images accurately. For deep learning, only the pixel array is extracted, so the header information is lost. When only the pixel array is displayed, the images may appear flipped or with inverted slice order. Vertical and horizontal flipping, as well as flipping along the z axis, can therefore be used for data augmentation.
torch.stack((original, RandomFlip3D()(original, split_idx = 0),
RandomFlip3D()(original, split_idx = 0),
RandomFlip3D()(original, split_idx = 0))).show(nrow = 15)
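The flipping described above can be sketched in plain PyTorch. This is a minimal illustration, not the implementation of RandomFlip3D: the helper name `random_flip_3d`, the independent per-axis probability `p`, and the (D, H, W) layout are all assumptions.

```python
import torch

def random_flip_3d(vol: torch.Tensor, p: float = 0.5, generator=None) -> torch.Tensor:
    """Hypothetical helper: flip a (D, H, W) volume along each axis with probability p."""
    # decide independently for the slice axis (-3), vertical (-2) and horizontal (-1) axis
    dims = [d for d in (-3, -2, -1)
            if torch.rand(1, generator=generator).item() < p]
    # torch.flip reverses element order along the chosen dims;
    # flipping dim -3 inverts the slice order
    return torch.flip(vol, dims=dims) if dims else vol

vol = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
flipped = random_flip_3d(vol, p=1.0)  # p=1 flips along all three axes
```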
Medical images should show no rotation. However, once the image file header is removed, the pixel array may appear rotated when displayed and thus be presented to the model rotated. Furthermore, in some images the patient may be rotated to some degree. Therefore, rotations by 90° and 180°, as well as by smaller angles, should be implemented.
torch.stack((original, RandomRotate3D()(original, split_idx = 0),
RandomRotate3D()(original, split_idx = 0), RandomRotate3D()(original, split_idx = 0))).show(nrow = 15)
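Rotations by multiples of 90° within each slice only permute voxels and need no interpolation. A minimal sketch with `torch.rot90`, assuming a (D, H, W) volume rotated in the H×W plane (i.e. about the z axis); `random_rotate_90_3d` is a hypothetical name, not the faimed3d API. Note that for non-square slices a 90° rotation swaps H and W.

```python
import torch

def random_rotate_90_3d(vol: torch.Tensor, generator=None) -> torch.Tensor:
    """Hypothetical helper: rotate a (D, H, W) volume about z by k * 90 degrees, k random."""
    k = int(torch.randint(0, 4, (1,), generator=generator).item())
    # dims=(-2, -1): rotate in the plane spanned by H and W
    return torch.rot90(vol, k, dims=(-2, -1))

vol = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
rot = torch.rot90(vol, 1, dims=(-2, -1))  # one 90° step, shape becomes (2, 4, 3)
```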
PyTorch does not support rotation of 3D images, so some transformations need to be applied slicewise.
tmp1 = RandomRotate3DBy()(original, split_idx = 0)
tmp2 = RandomRotate3DBy(p=1., degrees=(10, 10, 45), axis=[-1, -2, -3])(original, split_idx = 0)
original.show(nrow = 15)
tmp1.show(nrow = 15)
tmp2.show(nrow =15)
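One way to rotate slicewise by an arbitrary angle is to treat the slices as a batch of 2D images and resample them with `torch.nn.functional.affine_grid` and `grid_sample`. This is a sketch under assumptions, not the code behind RandomRotate3DBy: the helper `rotate_slices`, bilinear sampling, zero padding, and square slices (the rotation acts on normalized coordinates, so non-square slices would be distorted) are all illustrative choices. The zero-filled corners it produces are the empty margins that a subsequent random crop should remove.

```python
import math
import torch
import torch.nn.functional as F

def rotate_slices(vol: torch.Tensor, degrees: float) -> torch.Tensor:
    """Hypothetical helper: rotate every slice of a (D, H, W) volume about the z axis."""
    d = vol.shape[0]
    theta = math.radians(degrees)
    cos, sin = math.cos(theta), math.sin(theta)
    # 2x3 affine matrix in normalized coordinates, repeated once per slice
    mat = torch.tensor([[cos, -sin, 0.0],
                        [sin,  cos, 0.0]], dtype=torch.float32).unsqueeze(0).repeat(d, 1, 1)
    x = vol.float().unsqueeze(1)                  # (D, 1, H, W): slices as a 2D batch
    grid = F.affine_grid(mat, list(x.shape), align_corners=False)
    out = F.grid_sample(x, grid, mode="bilinear",
                        padding_mode="zeros", align_corners=False)
    return out.squeeze(1)                          # back to (D, H, W)

vol = torch.rand(4, 32, 32)
rotated = rotate_slices(vol, 30.0)  # corners outside the rotated slice are filled with 0
```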
Rotating by 90 (or 180 and 270) degrees should not be done via RandomRotate3DBy but with rotate_90_3d, as this is approximately 28 times faster.
As the 3D array can be flipped along all three axes but should only be rotated about the z axis, the resulting set of transformations is not a complete dihedral group. Still, multiple combinations of flipping and rotating should be implemented:
I am not sure if this is complete...
dihedral = RandomDihedral3D()
torch.stack((original, dihedral(original, split_idx = 0), dihedral(original, split_idx = 0),
dihedral(original, split_idx = 0),dihedral(original, split_idx = 0),
dihedral(original, split_idx = 0))).show(nrow=15)
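Counting the combinations may help judge completeness. In-plane rotations by multiples of 90° together with in-plane flips form the dihedral group of the square with 8 elements; an optional flip along z doubles this to 16 distinct transformations. Because flipping along both H and W equals a 180° rotation, the 4 × 2³ = 32 combinations of rotation and flips contain each transformation twice. The sketch below enumerates the 16 distinct ones; it assumes square slices (H = W) and a (D, H, W) layout, and `dihedral_3d` / `random_dihedral_3d` are hypothetical helpers, not the RandomDihedral3D implementation.

```python
import torch

def dihedral_3d(vol: torch.Tensor, k: int, flip_w: bool, flip_z: bool) -> torch.Tensor:
    """Hypothetical helper: one in-plane D4 element, optionally combined with a z flip."""
    out = torch.rot90(vol, k, dims=(-2, -1))   # rotate about z by k * 90 degrees
    if flip_w:
        out = torch.flip(out, dims=(-1,))      # in-plane reflection
    if flip_z:
        out = torch.flip(out, dims=(-3,))      # invert the slice order
    return out

def random_dihedral_3d(vol: torch.Tensor, generator=None) -> torch.Tensor:
    # draw one of the 16 distinct transformations uniformly
    idx = int(torch.randint(0, 16, (1,), generator=generator).item())
    return dihedral_3d(vol, k=idx % 4, flip_w=bool(idx // 4 % 2), flip_z=bool(idx // 8))

vol = torch.arange(2 * 4 * 4).reshape(2, 4, 4)
variants = [dihedral_3d(vol, k, fw, fz)
            for k in range(4) for fw in (False, True) for fz in (False, True)]
n_distinct = len({tuple(v.flatten().tolist()) for v in variants})  # 16
```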
A reasonable approach for 3D medical images is to presize them to a uniform but slightly too large volume and then randomly crop them to the target dimensions. As most areas of interest are located centrally in the image/volume, some cropping can always be applied.
Random cropping should also be applied after any rotation that is not by 90/180/270 degrees, so that the empty margins are cropped.
Crop = RandomCrop3D((10,50,50), (10,20,20), False)
torch.stack((Crop(original, split_idx = 0), Crop(original, split_idx = 0),
Crop(original, split_idx = 0), Crop(original, split_idx = 0))).show(nrow = 10)
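Random cropping after presizing amounts to choosing, per axis, a random start index such that the crop still fits inside the volume. `random_crop_3d` is a hypothetical helper with an assumed (D, H, W) layout and an assumed target-size argument; it does not reproduce the signature of RandomCrop3D.

```python
import torch

def random_crop_3d(vol: torch.Tensor, size, generator=None) -> torch.Tensor:
    """Hypothetical helper: randomly crop a (D, H, W) volume to size = (d, h, w)."""
    starts = []
    for dim_len, crop_len in zip(vol.shape[-3:], size):
        if crop_len > dim_len:
            raise ValueError("crop size exceeds volume size")
        # any start in [0, dim_len - crop_len] keeps the crop inside the volume
        starts.append(int(torch.randint(0, dim_len - crop_len + 1, (1,),
                                        generator=generator).item()))
    (z, y, x), (d, h, w) = starts, size
    return vol[..., z:z + d, y:y + h, x:x + w]

vol = torch.rand(20, 64, 64)               # presized, slightly larger than needed
patch = random_crop_3d(vol, (16, 56, 56))  # random patch of the target size
```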
im = Crop(original).resize_3d((10, 100, 100))
crop_mask = TensorMask3D(torch.ones(4, 100, 20)).pad_to((10, 100, 100))
crop_mask = crop_mask + crop_mask.rotate_90_3d()
crop_mask = torch.where(crop_mask == 0, 0, 1)
crop_mask2 = TensorMask3D(torch.ones(10, 100, 20)).pad_to((10, 100, 100))
crop_mask2 = crop_mask2 + crop_mask2.rotate_90_3d()
crop_mask2 = torch.where(crop_mask2 == 0, 1, 0)
crop_mask.show()
crop_mask2.show()
MaskErease(mask = crop_mask)(im).show()
MaskErease(mask = crop_mask2)(im).show()
im2 = TensorDicom3D.create('../data/example_grid.nii.gz')
im2 = im2.unsqueeze(0)
im2.show()
RandomPerspective3D(im.size(-1), p = 1.)(im2, split_idx=0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomWarp3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomSheer3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
RandomTrapezoid3D(p=1, max_magnitude=0.5)(im2, split_idx = 0).show()
from faimed3d.preprocess import mean_scale
im = Resize3D((10, 224, 224))(TensorDicom3D.create('../data/series/radiopaedia_10_85902_1.nii.gz')) # redefine for show_docs
Noise = RandomNoise3D(p=1)
RandomNoise3D(p=1)(im.mean_scale(), split_idx=0).show()
RandomNoise3D(p=1)(im.mean_scale(), split_idx=0).show()
RandomNoise3D(p=1)(im.mean_scale(), split_idx=0).show()
RandomNoise3D(p=1)(im.mean_scale(), split_idx=0).show()
RandomBlur3D(p=1., sigma = 10)(im, split_idx=0).show()
torch.stack((im.mean_scale(),
RandomBrightness3D(p=1., beta_range=[0.9, 1])(im.mean_scale(), split_idx = 0),
RandomBrightness3D(p=1., beta_range=[-0.9, -1])(im.mean_scale(), split_idx = 0))).show()
im.mean_scale().show()
RandomContrast3D(p=1.)(im.mean_scale(), split_idx = 0).show()
RandomContrast3D(p=1.)(im.mean_scale(), split_idx = 0).show()
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import gaussian_filter
def elastic_transform_3d(image, labels=None, alpha=4, sigma=35, bg_val=0.1):
    """
    Elastic deformation of images as described in
    Simard, Steinkraus and Platt, "Best Practices for
    Convolutional Neural Networks applied to Visual
    Document Analysis", in
    Proc. of the International Conference on Document Analysis and
    Recognition, 2003.
    Modified from:
    https://gist.github.com/chsasank/4d8f68caf01f041a6453e67fb30f8f5a
    https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62
    Modified to take 3D inputs.
    Deforms both the image and the corresponding label volume:
    the image is linearly (trilinearly) interpolated,
    label volumes are nearest-neighbour interpolated.
    """
    assert image.ndim == 3
    shape = image.shape
    dtype = image.dtype
    # Define coordinate system
    coords = np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2])
    # Initialize interpolators
    im_intrps = RegularGridInterpolator(coords, image,
                                        method="linear",
                                        bounds_error=False,
                                        fill_value=bg_val)
    # Get random elastic deformations: uniform noise in [-1, 1], smoothed and scaled by alpha
    dx = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha
    dy = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha
    dz = gaussian_filter((np.random.rand(*shape) * 2 - 1), sigma,
                         mode="constant", cval=0.) * alpha
    # Define sample points (original grid displaced by dx, dy, dz)
    x, y, z = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]]
    indices = (np.reshape(x + dx, (-1, 1)),
               np.reshape(y + dy, (-1, 1)),
               np.reshape(z + dz, (-1, 1)))
    # Interpolate the 3D image at the displaced points, keeping the input dtype
    image = im_intrps(indices).reshape(shape).astype(dtype)
    # Interpolate labels
    if labels is not None:
        lab_intrp = RegularGridInterpolator(coords, labels,
                                            method="nearest",
                                            bounds_error=False,
                                            fill_value=0)
        labels = lab_intrp(indices).reshape(shape).astype(labels.dtype)
        return image, labels
    return image
A good workflow is to apply a random crop to all images after each transformation. For this, the images should first be presized to a size just a few pixels larger than desired, then transformed, and finally cropped to the final size. With this approach, empty space, which appears e.g. after RandomRotate3DBy, is cropped away and does not influence the accuracy of the model. One only has to make sure that the region of interest, e.g. the prostate, is contained in every cropped image.
Crop = RandomCrop3D((2,10,10), (1,2,2))
tfms = [RandomBrightness3D(), RandomContrast3D(), RandomWarp3D(), RandomDihedral3D(), RandomNoise3D(), RandomRotate3DBy()]
tfms = [Pipeline([RandomBrightness3D(p=1.), Crop], split_idx = 0),
Pipeline([RandomContrast3D(p=1.), Crop], split_idx = 0),
Pipeline([RandomWarp3D(p=1.), Crop], split_idx = 0),
Pipeline([RandomDihedral3D(p=1.), Crop], split_idx = 0),
Pipeline([RandomNoise3D(p=1.), Crop], split_idx = 0),
Pipeline([RandomRotate3DBy(p=1.), Crop], split_idx = 0)]
comp = setup_aug_tfms(tfms)
comp
ims = [t(im).squeeze() for t in tfms]
torch.stack(ims).show(nrow = 6)
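The presize, transform, crop workflow can also be written as a single plain function. Everything here is an assumption for illustration: `presize_transform_crop` is hypothetical, presizing uses trilinear `torch.nn.functional.interpolate`, the transform is any callable on a (D, H, W) tensor that keeps the volume at least as large as the final size, and the crop position is random.

```python
import torch
import torch.nn.functional as F

def presize_transform_crop(vol, presize, transform, final_size, generator=None):
    """Hypothetical helper: resize (D, H, W) to presize, transform, then randomly crop."""
    # interpolate expects (N, C, D, H, W) for trilinear resizing
    x = F.interpolate(vol.float()[None, None], size=tuple(presize),
                      mode="trilinear", align_corners=False)[0, 0]
    x = transform(x)
    # random crop: this removes e.g. the empty margins left by a rotation
    starts = [int(torch.randint(0, p - f + 1, (1,), generator=generator).item())
              for p, f in zip(x.shape, final_size)]
    slices = tuple(slice(s, s + f) for s, f in zip(starts, final_size))
    return x[slices]

out = presize_transform_crop(torch.rand(30, 200, 200), (12, 110, 110),
                             lambda t: torch.flip(t, dims=(-1,)), (10, 100, 100))
```

Presizing only a few voxels larger than the target keeps most of the image content while still leaving room for the crop to discard border artifacts.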
@patch
def make_pseudo_color(t: (TensorDicom3D, TensorMask3D)):
    '''
    The 3D CNN still expects images with a color (channel) dimension, so a single-channel
    pseudo-color image needs to be created until the 3D CNN is adapted.
    '''
    if t.ndim == 3:
        return t.unsqueeze(0).float()
    elif t.ndim == 4:
        return t.unsqueeze(1).float()
    else:
        return t
class PseudoColor(RandTransform):
    split_idx, p = None, 1
    def __init__(self, p=1):
        super().__init__(p=p)
    def __call__(self, b, split_idx=None, **kwargs):
        "Overrides __call__ to enforce that the transform is always applied, on every dataset."
        return super().__call__(b, split_idx=split_idx, **kwargs)
    def encodes(self, x:(TensorDicom3D, TensorMask3D)):
        return x.make_pseudo_color()
MakeColor = PseudoColor()
im.shape, MakeColor(im, split_idx = 0).shape
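The extra dimension is needed because PyTorch's 3D convolutions work on tensors laid out as (N, C, D, H, W). The short check below uses `torch.nn.Conv3d` with arbitrary layer sizes chosen only for illustration; it mirrors the two branches of make_pseudo_color: a single volume gets a channel dimension in front, while a batch of volumes gets it at position 1.

```python
import torch
import torch.nn as nn

# arbitrary example layer: 1 input channel, 8 output channels, size-preserving padding
conv = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)

vol = torch.rand(10, 64, 64)          # single volume (D, H, W)
item = vol.unsqueeze(0)               # (1, D, H, W), as for ndim == 3
batch = torch.stack([item, item])     # collated batch: (2, 1, D, H, W)
out = conv(batch)                     # (2, 8, D, H, W)

vols = torch.rand(4, 10, 64, 64)      # batch of volumes (B, D, H, W)
out2 = conv(vols.unsqueeze(1))        # channel inserted at dim 1, as for ndim == 4
```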
tmp = Pipeline(aug_transforms_3d(p_all = 1.), split_idx=0)(im)
print(tmp.size())
tmp.show()
mask.reduce_classes([1]).unique()